package com.ruoyi.gateway.filter;

import org.springframework.util.StreamUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Request wrapper that lets a filter add or override request headers and
 * that caches the request body so it can be read more than once.
 *
 * <p>The whole body is copied into memory in the constructor; every call to
 * {@link #getInputStream()} or {@link #getReader()} returns a fresh view that
 * starts at the beginning of that cached copy.
 *
 * <p>Headers registered through {@link #addHeader(String, String)} are matched
 * case-insensitively and take precedence over headers of the same name on the
 * wrapped request, consistently across {@link #getHeader(String)},
 * {@link #getHeaders(String)} and {@link #getHeaderNames()}.
 *
 * @author: songzy
 */
public class CustomizeRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used by {@link #getReader()} when the request declares none. */
    private static final String FALLBACK_ENCODING = "UTF-8";

    /** Cached copy of the request body. */
    private final byte[] requestBody;

    /** Custom headers; keys compare case-insensitively, as HTTP header names do. */
    private final Map<String, String> headerMap =
            new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);

    /**
     * Wraps the given request and eagerly reads its body into memory.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public CustomizeRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        requestBody = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns the request wrapped by this instance.
     *
     * @return the wrapped request, cast to {@link HttpServletRequest}
     */
    public HttpServletRequest getDefaultHttpServletRequest() {
        return (HttpServletRequest) super.getRequest();
    }

    /**
     * Registers a custom header. A later call with the same name (ignoring
     * case) replaces the earlier value.
     *
     * @param name  header name; must not be {@code null}
     * @param value header value
     * @throws IllegalArgumentException if {@code name} is {@code null}
     */
    public void addHeader(String name, String value) {
        if (name == null) {
            throw new IllegalArgumentException("Header name must not be null");
        }
        headerMap.put(name, value);
    }

    /**
     * Returns the custom value registered for {@code name} if there is one,
     * otherwise the value from the wrapped request.
     *
     * @param name header name (case-insensitive for custom headers)
     * @return the header value, or whatever the wrapped request returns
     */
    @Override
    public String getHeader(String name) {
        if (hasCustomHeader(name)) {
            return headerMap.get(name);
        }
        return super.getHeader(name);
    }

    /**
     * Returns the header names of the wrapped request followed by any custom
     * header names that are not already present (compared ignoring case), so
     * an overridden header is listed only once.
     *
     * @return enumeration of header names; never {@code null}
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        Enumeration<String> original = super.getHeaderNames();
        // Guard against a null enumeration from the wrapped request.
        List<String> names = Collections.list(
                original != null ? original : Collections.<String>emptyEnumeration());
        for (String customName : headerMap.keySet()) {
            if (!containsIgnoreCase(names, customName)) {
                names.add(customName);
            }
        }
        return Collections.enumeration(names);
    }

    /**
     * Returns the values for {@code name}. When a custom header is registered
     * under that name, only the custom value is returned so that this method
     * agrees with {@link #getHeader(String)}; otherwise the wrapped request's
     * values are returned.
     *
     * @param name header name (case-insensitive for custom headers)
     * @return enumeration of header values; never {@code null}
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        if (hasCustomHeader(name)) {
            return Collections.enumeration(Collections.singletonList(headerMap.get(name)));
        }
        Enumeration<String> original = super.getHeaders(name);
        return original != null ? original : Collections.<String>emptyEnumeration();
    }

    /**
     * Returns a new stream over the cached request body, positioned at the start.
     *
     * @return a stream that reads the cached body
     */
    @Override
    public ServletInputStream getInputStream() {
        return new CachedBodyInputStream(requestBody);
    }

    /**
     * Returns a new reader over the cached request body. The body is decoded
     * with the request's character encoding, or UTF-8 when none is declared.
     *
     * @return a buffered reader over the cached body
     * @throws IOException if the declared character encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = FALLBACK_ENCODING;
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /** Tells whether a custom header is registered under the given name. */
    private boolean hasCustomHeader(String name) {
        // The case-insensitive comparator cannot handle null keys, so check first.
        return name != null && headerMap.containsKey(name);
    }

    /** Tells whether {@code names} contains {@code candidate}, ignoring case. */
    private static boolean containsIgnoreCase(List<String> names, String candidate) {
        for (String existing : names) {
            if (candidate.equalsIgnoreCase(existing)) {
                return true;
            }
        }
        return false;
    }

    /** {@link ServletInputStream} backed by an in-memory copy of the body. */
    private static final class CachedBodyInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        CachedBodyInputStream(byte[] body) {
            this.delegate = new ByteArrayInputStream(body);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] buffer, int offset, int length) {
            // Bulk read straight from the array instead of looping over read().
            return delegate.read(buffer, offset, length);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            // Finished once every cached byte has been consumed.
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // Data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Intentionally a no-op: the listener is ignored and never invoked.
        }
    }
}

